{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JNwMxwcaa05q"
      },
      "source": [
        "# Sharp edges in Differentiable Swift\n",
        "Differentiable Swift has come a long way in terms of usability. Here is a heads-up about the parts that are still a little non-obvious. As progress continues, this guide will get shorter, and you'll be able to write differentiable code without needing special syntax.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_LTY5_lZbLMU"
      },
      "source": [
        "## Loops\n",
        "\n",
        "Loops are differentiable, with one detail to know about: when you write the loop, wrap the sequence or bound you're looping over in `withoutDerivative(at:)`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "mvBQ2eCfTDrj"
      },
      "outputs": [],
      "source": [
        "var a: [Float] = [1,2,3]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "to30MPP2TJqi"
      },
      "source": [
        "for example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "9NRXNazqTMqV"
      },
      "outputs": [],
      "source": [
        "for _ in a.indices \n",
        "{}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wPrrZ7qUTUMG"
      },
      "source": [
        "becomes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "zzwe0d8JTVN7"
      },
      "outputs": [],
      "source": [
        "for _ in withoutDerivative(at: a.indices) \n",
        "{}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UITuC4g9TYEp"
      },
      "source": [
        "or:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "YRbV8eqATcnG"
      },
      "outputs": [],
      "source": [
        "for _ in 0..<a.count \n",
        "{}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X5O0DPm6Tfxl"
      },
      "source": [
        "becomes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "n5c0Dsc3TllX"
      },
      "outputs": [],
      "source": [
        "for _ in 0..<withoutDerivative(at: a.count) \n",
        "{}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BJpcKJbTTvTC"
      },
      "source": [
        "This is necessary because members like `Array.count` and `Array.indices` don't contribute to the derivative with respect to the array. Only the actual elements in the array contribute to the derivative.\n",
        "\n",
        "If you've got a loop where you manually use an integer as the upper bound, there's no need to use `withoutDerivative(at:)`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "EDL4ykoZTwcl"
      },
      "outputs": [],
      "source": [
        "let iterations: Int = 10\n",
        "for _ in 0..<iterations {} //this is fine as-is."
      ]
    },
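    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Putting the pieces together, here is a minimal sketch of a differentiable function built around such a loop. The name `sumOfSquares` is made up for illustration, and the snippet assumes the same toolchain as the rest of this notebook (newer toolchains may need an explicit import and may spell some of these APIs differently).\n",
        "\n",
        "```swift\n",
        "// Hypothetical example function; not part of any library.\n",
        "@differentiable\n",
        "func sumOfSquares(_ values: [Float]) -> Float {\n",
        "    var total: Float = 0\n",
        "    // The indices carry no derivative, so they are wrapped.\n",
        "    for i in withoutDerivative(at: values.indices) {\n",
        "        // Reading an element is differentiable as-is.\n",
        "        total = total + values[i] * values[i]\n",
        "    }\n",
        "    return total\n",
        "}\n",
        "\n",
        "// Mathematically, the derivative of sum(x_i^2) with respect to x_i is 2 * x_i.\n",
        "let sumOfSquaresGrad = gradient(at: [1, 2, 3], in: sumOfSquares)\n",
        "print(sumOfSquaresGrad)\n",
        "```\n",
        "\n",
        "Only the loop header changes; the loop body reads elements the usual way."
      ]
    },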
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GBuKyDDVcb8M"
      },
      "source": [
        "## Map and Reduce\n",
        "`map` and `reduce` have special differentiable versions, `differentiableMap` and `differentiableReduce`, that work just like the ones you're used to:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "vzZl_P6-W_nD"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "aPlusOne [2.0, 3.0, 4.0]\r\n",
            "aSum 6.0\r\n"
          ]
        }
      ],
      "source": [
        "a = [1,2,3]\n",
        "let aPlusOne = a.differentiableMap {$0 + 1}\n",
        "let aSum = a.differentiableReduce(0, +)\n",
        "print(\"aPlusOne\", aPlusOne)\n",
        "print(\"aSum\", aSum)"
      ]
    },
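    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see these two methods inside a differentiated function, here is a small sketch. The function name `sumOfDoubled` is hypothetical, and the snippet assumes the same toolchain as the rest of this notebook.\n",
        "\n",
        "```swift\n",
        "// Hypothetical example function; not part of any library.\n",
        "@differentiable\n",
        "func sumOfDoubled(_ values: [Float]) -> Float {\n",
        "    // Double every element, then add the results up.\n",
        "    let doubled = values.differentiableMap { $0 * 2 }\n",
        "    return doubled.differentiableReduce(0, +)\n",
        "}\n",
        "\n",
        "// Each element contributes 2 * x_i to the sum, so each partial derivative is 2.\n",
        "let sumOfDoubledGrad = gradient(at: [1, 2, 3], in: sumOfDoubled)\n",
        "print(sumOfDoubledGrad)\n",
        "```\n",
        "\n",
        "Swapping in the plain `map` or `reduce` here is the kind of change that would make the function stop differentiating."
      ]
    },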
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qhHxI3xwck9j"
      },
      "source": [
        "## Array subscript sets\n",
        "Array subscript sets (`array[0] = 0`) aren't differentiable out of the box, but you can paste in this extension:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "vj5XDwl0XEGi"
      },
      "outputs": [],
      "source": [
        "extension Array where Element: Differentiable {\n",
        "    // Differentiable stand-in for `self[index] = newValue`.\n",
        "    @differentiable(where Element: Differentiable)\n",
        "    mutating func updated(at index: Int, with newValue: Element) {\n",
        "        self[index] = newValue\n",
        "    }\n",
        " \n",
        "    @derivative(of: updated)\n",
        "    mutating func vjpUpdated(at index: Int, with newValue: Element)\n",
        "      -> (value: Void, pullback: (inout TangentVector) -> (Element.TangentVector))\n",
        "    {\n",
        "        self.updated(at: index, with: newValue)\n",
        "        return ((), { v in\n",
        "            // The gradient at `index` now belongs to `newValue`.\n",
        "            let dElement = v[index]\n",
        "            // The overwritten element no longer affects the result.\n",
        "            v.base[index] = .zero\n",
        "            return dElement\n",
        "        })\n",
        "    }\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mCkdO-F8XLFo"
      },
      "source": [
        "Then use the workaround syntax. Given an array like this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "GxZnhZGdXMm0"
      },
      "outputs": [],
      "source": [
        "var b: [Float] = [1,2,3]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qC5d-l3nXPxl"
      },
      "source": [
        "instead of this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "5YNUPyS3XUQ-"
      },
      "outputs": [],
      "source": [
        "b[0] = 17"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G1sCDA9zXWSA"
      },
      "source": [
        "write this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "ze-zTQP-XbN8"
      },
      "outputs": [],
      "source": [
        "b.updated(at: 0, with: 17)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bjDuPALzfKQC"
      },
      "source": [
        "Let's make sure it works:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "uKTN_ET6fNUc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(value: 3.0, gradient: [1.0])\r\n"
          ]
        }
      ],
      "source": [
        "func plusOne(array: [Float]) -> Float{\n",
        "  var array = array\n",
        "  array.updated(at: 0, with: array[0] + 1)\n",
        "  return array[0]\n",
        "}\n",
        "\n",
        "let plusOneValAndGrad = valueWithGradient(at: [2], in: plusOne)\n",
        "print(plusOneValAndGrad)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6-xCTZXNXf5c"
      },
      "source": [
        "The error you'll get without this workaround is `Differentiation of coroutine calls is not yet supported`.\n",
        "Here is the link to see progress on making this workaround unnecessary: https://bugs.swift.org/browse/TF-1277 (it talks about Array.subscript._modify, which is what's called behind the scenes when you do an array subscript set)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nU5-Mheme8Aj"
      },
      "source": [
        "## `Float` <-> `Double` conversions\n",
        "If you're switching between `Float` and `Double`, their initializers aren't differentiable out of the box. Here's a function that lets you go from a `Float` to a `Double` differentiably.\n",
        "\n",
        "(Swap `Float` and `Double` in the code below, and you've got a function that converts from `Double` to `Float`.)\n",
        "\n",
        "You can make similar converters for any other real Numeric types."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "vc0eakpEYE8B"
      },
      "outputs": [],
      "source": [
        "@differentiable\n",
        "func convertToDouble(_ a: Float) -> Double {\n",
        "    return Double(a)\n",
        "}\n",
        " \n",
        "@derivative(of: convertToDouble)\n",
        "func convertToDoubleVJP(_ a: Float) -> (value: Double, pullback: (Double) -> Float) {\n",
        "    func pullback(_ v: Double) -> Float{\n",
        "        return Float(v)\n",
        "    }\n",
        "    return (value: Double(a), pullback: pullback)\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W9Tup97yfu-x"
      },
      "source": [
        "Here's an example usage:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "pNeh1vHWfyOI"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "grad 2.0\r\n",
            "type of input: Float\r\n",
            "type of output: Double\r\n",
            "type of gradient: Float\r\n"
          ]
        }
      ],
      "source": [
        "@differentiable\n",
        "func timesTwo(a: Float) -> Double {\n",
        "  return convertToDouble(a * 2)\n",
        "}\n",
        "let input: Float = 3\n",
        "let valAndGrad = valueWithGradient(at: input, in: timesTwo)\n",
        "print(\"grad\", valAndGrad.gradient)\n",
        "print(\"type of input:\", type(of: input))\n",
        "print(\"type of output:\", type(of: valAndGrad.value))\n",
        "print(\"type of gradient:\", type(of: valAndGrad.gradient))"
      ]
    },
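    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For completeness, here is what the reverse converter could look like once `Float` and `Double` are swapped as described above. The names `convertToFloat` and `convertToFloatVJP` are chosen here to mirror the original and are otherwise arbitrary; the snippet assumes the same toolchain as the rest of this notebook.\n",
        "\n",
        "```swift\n",
        "// Mirror image of convertToDouble: Double in, Float out.\n",
        "@differentiable\n",
        "func convertToFloat(_ a: Double) -> Float {\n",
        "    return Float(a)\n",
        "}\n",
        "\n",
        "@derivative(of: convertToFloat)\n",
        "func convertToFloatVJP(_ a: Double) -> (value: Float, pullback: (Float) -> Double) {\n",
        "    // The incoming gradient is a Float; hand it back as a Double.\n",
        "    func pullback(_ v: Float) -> Double {\n",
        "        return Double(v)\n",
        "    }\n",
        "    return (value: Float(a), pullback: pullback)\n",
        "}\n",
        "\n",
        "let convertToFloatGrad = gradient(at: 3, in: convertToFloat)\n",
        "```\n",
        "\n",
        "Note that narrowing from `Double` to `Float` can lose precision in the value, which is a property of the conversion itself and not of the derivative."
      ]
    },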
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hR-ED1-Jf1aH"
      },
      "source": [
        "## Transcendental and other functions (sin, cos, abs, max)\n",
        "A lot of transcendentals and other common built-in functions have already been made differentiable for `Float` and `Double`. There are fewer for `Double` than `Float`. Some aren't available for either. So here are a few manual derivative definitions to give you the idea of how to make what you need, in case it isn't already provided:\n",
        "\n",
        "pow (see [link](https://www.wolframalpha.com/input/?i=partial+derivatives+of+f%28x%2Cy%29+%3D+x%5Ey) for derivative explanation)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "JPXr_xtwYOV9"
      },
      "outputs": [],
      "source": [
        "import Foundation\n",
        "\n",
        "@usableFromInline\n",
        "@derivative(of: pow)\n",
        "func powVJP(_ base: Double, _ exponent: Double) -> (value: Double, pullback: (Double) -> (Double, Double)) {\n",
        "    let output: Double = pow(base, exponent)\n",
        "    func pullback(_ vector: Double) -> (Double, Double) {\n",
        "        // d/d(base) of base^exponent = exponent * base^(exponent - 1)\n",
        "        let baseDerivative = vector * (exponent * pow(base, exponent - 1))\n",
        "        // d/d(exponent) of base^exponent = base^exponent * log(base)\n",
        "        let exponentDerivative = vector * output * log(base)\n",
        "        return (baseDerivative, exponentDerivative)\n",
        "    }\n",
        " \n",
        "    return (value: output, pullback: pullback)\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CxwAtO-vYPp0"
      },
      "source": [
        "max"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "P_cqKUe1YWQz"
      },
      "outputs": [],
      "source": [
        "@usableFromInline\n",
        "@derivative(of: max)\n",
        "func maxVJP<T: Comparable & Differentiable>(_ x: T, _ y: T) -> (value: T, pullback: (T.TangentVector)\n",
        "  -> (T.TangentVector, T.TangentVector))\n",
        "{\n",
        "    func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {\n",
        "        if x < y {\n",
        "            return (.zero, v)\n",
        "        } else {\n",
        "            return (v, .zero)\n",
        "        }\n",
        "    }\n",
        "    return (value: max(x, y), pullback: pullback)\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fwwzgOnNYXUR"
      },
      "source": [
        "abs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "iNCladNcYaio"
      },
      "outputs": [],
      "source": [
        "@usableFromInline\n",
        "@derivative(of: abs)\n",
        "func absVJP<T: Comparable & SignedNumeric & Differentiable>(_ x: T)\n",
        "  -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)\n",
        "{\n",
        "    func pullback(_ v: T.TangentVector) -> T.TangentVector{\n",
        "        if x < 0 {\n",
        "            return .zero - v\n",
        "        }\n",
        "        else {\n",
        "            return v\n",
        "        }\n",
        "    }\n",
        "    return (value: abs(x), pullback: pullback)\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OnLo4C7fYeK6"
      },
      "source": [
        "sqrt (see [link](https://www.wolframalpha.com/input/?i=partial+derivative+of+f%28x%29+%3D+sqrt%28x%29) for derivative explanation)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "bIO0M5ONYiRD"
      },
      "outputs": [],
      "source": [
        "@usableFromInline\n",
        "@derivative(of: sqrt) \n",
        "func sqrtVJP(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {\n",
        "    let output = sqrt(x)\n",
        "    func pullback(_ v: Double) -> Double {\n",
        "        return v / (2 * output)\n",
        "    }\n",
        "    return (value: output, pullback: pullback)\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H45t3grVj2bx"
      },
      "source": [
        "Let's check that these work:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "liU1ZR8_j5VN"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "pow gradient:  (4.0, 2.772588722239781) which is correct\r\n",
            "max gradient:  (0.0, 1.0) which is correct\r\n",
            "abs gradient:  1.0 which is correct\r\n",
            "sqrt gradient:  0.25 which is correct\r\n"
          ]
        }
      ],
      "source": [
        "let powGrad = gradient(at: 2, 2, in: pow)\n",
        "print(\"pow gradient: \", powGrad, \"which is\", powGrad == (4.0, 2.772588722239781) ? \"correct\" : \"incorrect\")\n",
        "\n",
        "let maxGrad = gradient(at: 1, 2, in: max)\n",
        "print(\"max gradient: \", maxGrad, \"which is\", maxGrad == (0.0, 1.0) ? \"correct\" : \"incorrect\")\n",
        "\n",
        "let absGrad = gradient(at: 2, in: abs)\n",
        "print(\"abs gradient: \", absGrad, \"which is\", absGrad == 1.0 ? \"correct\" : \"incorrect\")\n",
        "\n",
        "let sqrtGrad = gradient(at: 4, in: sqrt)\n",
        "print(\"sqrt gradient: \", sqrtGrad, \"which is\", sqrtGrad == 0.25 ? \"correct\" : \"incorrect\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bjXglKrIYpUy"
      },
      "source": [
        "The compiler error that alerts you to the need for something like this is: `Expression is not differentiable. Cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZaBX70wGjlZ9"
      },
      "source": [
        "## `KeyPath` subscripting\n",
        "`KeyPath` subscripting (get or set) doesn't work out of the box, but once again, there are some extensions you can add, and then use a workaround syntax. Here it is:\n",
        "\n",
        "https://github.com/tensorflow/swift/issues/530#issuecomment-687400701\n",
        "\n",
        "This workaround is a little uglier than the others. It only works for custom objects, which must conform to `Differentiable` and `AdditiveArithmetic`. You have to add a `.tmp` member and a `.read()` function, and you use the `.tmp` member as intermediate storage when doing `KeyPath` subscript gets (there is an example in the linked code). `KeyPath` subscript sets work pretty simply with a `.write()` function.\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "Swift_autodiff_sharp_edges.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Swift",
      "name": "swift"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
